{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework: Simple Neural Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model achieves a test accuracy of 98.2% by introducing two hidden layers with the ReLU activation function. The initial weights are obtained by random initialization, while the biases are initialized to zeros. The classification function (softmax) and the optimizer (GradientDescentOptimizer) are kept unchanged from the homework script, and the learning rate is also left unaltered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yupt/.local/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-e7d8bd774d53>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /home/yupt/.local/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /home/yupt/.local/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /home/yupt/.local/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/yupt/.local/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/yupt/.local/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
    "# Read the MNIST data sets from the MNIST_data/ directory; one_hot=True requests one-hot labels.\n",
    "mnist = input_data.read_data_sets(\n",
    "    \"MNIST_data/\",\n",
    "    one_hot=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# load data: training/test arrays plus the graph inputs fed at run time\n",
    "trX, trY = mnist.train.images, mnist.train.labels\n",
    "teX, teY = mnist.test.images, mnist.test.labels\n",
    "X = tf.placeholder(tf.float32, shape=[None, 784])  # one row of 784 values per example\n",
    "Y = tf.placeholder(tf.float32, shape=[None, 10])  # one row of 10 label values per example\n",
    "\n",
    "#define weights function\n",
    "def init_weights(shape):\n",
    "    \"\"\"Return a tf.Variable of `shape` initialized from a random normal with stddev 0.05.\"\"\"\n",
    "    initial_values = tf.random_normal(shape, stddev=0.05)\n",
    "    return tf.Variable(initial_values)\n",
    "  \n",
    "#initialize weights and biases\n",
    "# Layer widths: 784 inputs -> 600 -> 200 -> 10 outputs.\n",
    "# Weights are created first, then biases, each in layer order.\n",
    "w_h1, w_h2, w_o = (\n",
    "    init_weights(dims) for dims in ([784, 600], [600, 200], [200, 10])\n",
    ")\n",
    "b_h1, b_h2, b_o = (\n",
    "    tf.Variable(tf.zeros([width])) for width in (600, 200, 10)\n",
    ")\n",
    "\n",
    "\n",
    "                           \n",
    "\n",
    "#define model function\n",
    "def model(X, w_h1, w_h2, w_o, b_h1=b_h1, b_h2=b_h2, b_o=b_o):\n",
    "    \"\"\"Build the two-hidden-layer network and return its logits.\n",
    "\n",
    "    Args:\n",
    "        X: input tensor, multiplied by ``w_h1``.\n",
    "        w_h1, w_h2, w_o: weight matrices of the first hidden, second hidden\n",
    "            and output layers.\n",
    "        b_h1, b_h2, b_o: matching bias vectors. They default to the\n",
    "            module-level bias variables (bound when this function is\n",
    "            defined), so ``model(X, w_h1, w_h2, w_o)`` keeps working while\n",
    "            the dependency on the biases is explicit in the signature.\n",
    "\n",
    "    Returns:\n",
    "        The output-layer pre-activation values (logits); no softmax is\n",
    "        applied inside this function.\n",
    "    \"\"\"\n",
    "    #The first hidden layer\n",
    "    h1 = tf.nn.relu(tf.matmul(X, w_h1) + b_h1)\n",
    "\n",
    "    #The second hidden layer\n",
    "    h2 = tf.nn.relu(tf.matmul(h1, w_h2) + b_h2)\n",
    "\n",
    "    #Output layer: this is the value of the logits\n",
    "    return tf.matmul(h2, w_o) + b_o\n",
    "\n",
    "# logits produced by the network for the placeholder input X\n",
    "y_ = model(X, w_h1, w_h2, w_o)\n",
    "\n",
    "#define cost and train_op\n",
    "# cost: mean softmax cross-entropy between the logits and the labels Y\n",
    "cost = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=y_)\n",
    ")\n",
    "# train_op: one gradient-descent step on cost with learning rate 0.5\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cost)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9821\n"
     ]
    }
   ],
   "source": [
    "#Train the model\n",
    "# The session is used as a context manager so it is closed even if\n",
    "# training or evaluation raises, instead of relying on a trailing close().\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    # 3000 gradient-descent steps, each on a mini-batch of 100 training examples\n",
    "    for _ in range(3000):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(100)\n",
    "        sess.run(train_op, feed_dict = {X:batch_xs, Y:batch_ys})\n",
    "\n",
    "    #Test the trained model\n",
    "    # A prediction is correct when the arg-max of the logits matches the arg-max of the label row\n",
    "    correct_prediction = tf.equal(tf.argmax(y_,1), tf.argmax(Y,1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    print(sess.run(accuracy, feed_dict = {X:mnist.test.images, Y:mnist.test.labels}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
